package com.nulldev.util.internal.backport.httpclient_rw.impl.websocket;

import static com.nulldev.util.internal.backport.httpclient_rw.impl.websocket.TransportImpl.ChannelState.*;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.SelectionKey;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import com.nulldev.util.internal.backport.concurrency9.concurrent.CompletableFuture;
import com.nulldev.util.internal.backport.httpclient_rw.impl.common.Demand;
import com.nulldev.util.internal.backport.httpclient_rw.impl.common.Logger;
import com.nulldev.util.internal.backport.httpclient_rw.impl.common.MinimalFuture;
import com.nulldev.util.internal.backport.httpclient_rw.impl.common.SequentialScheduler;
import com.nulldev.util.internal.backport.httpclient_rw.impl.common.SequentialScheduler.CompleteRestartableTask;
import com.nulldev.util.internal.backport.httpclient_rw.impl.common.Utils;

/*
 * Transport that moves WebSocket messages between a MessageQueue /
 * MessageStreamConsumer pair and a RawChannel.
 *
 * Outgoing: each sendXXX method enqueues a message and kicks the send
 * scheduler. SendTask encodes the head message into a reusable buffer and
 * writes it to the channel, registering a write event whenever the channel
 * does not accept all the bytes at once.
 *
 * Incoming: request(n) raises the demand and kicks the receive scheduler.
 * ReceiveTask reads bytes from the channel (registering a read event when
 * none are available) and feeds them frame by frame to a MessageDecoder
 * built around the consumer.
 */
public class TransportImpl implements Transport {

	// -- Debugging infrastructure --

	private static final Logger debug = Utils.getWebSocketLogger("[Transport]"::toString, Utils.DEBUG_WS);

	/* Used for correlating enters to and exits from a method */
	private final AtomicLong counter = new AtomicLong();

	private final SequentialScheduler sendScheduler = new SequentialScheduler(new SendTask());

	private final MessageQueue queue;
	private final MessageEncoder encoder = new MessageEncoder();
	/* A reusable buffer for writing, initially with no remaining bytes */
	// The cast keeps this compiling where Buffer.position/limit return Buffer
	private final ByteBuffer dst = (ByteBuffer) createWriteBuffer().position(0).limit(0);
	/* This array is created once for gathering writes accepted by RawChannel */
	private final ByteBuffer[] dstArray = new ByteBuffer[] { dst };
	private final MessageStreamConsumer messageConsumer;
	private final MessageDecoder decoder;
	private final Frame.Reader reader = new Frame.Reader();

	private final Demand demand = new Demand();
	private final SequentialScheduler receiveScheduler;
	private final RawChannel channel;
	/* Guards inputClosed and outputClosed */
	private final Object closeLock = new Object();
	private final RawChannel.RawEvent writeEvent = new WriteEvent();
	private final RawChannel.RawEvent readEvent = new ReadEvent();
	/* Updated by both the send task and WriteEvent.handle(), hence CAS-based */
	private final AtomicReference<ChannelState> writeState = new AtomicReference<>(UNREGISTERED);
	/* Current input buffer; replaced by each channel.read() in ReceiveTask */
	private ByteBuffer data;
	private volatile ChannelState readState = UNREGISTERED;
	private boolean inputClosed;
	private boolean outputClosed;

	/**
	 * Creates a transport.
	 *
	 * @param queue    queue of outgoing messages, drained by the send task
	 * @param consumer receiver of decoded incoming messages and of
	 *                 completion/error signals
	 * @param channel  the underlying channel; its initial byte buffer is the
	 *                 first input processed by the receive task
	 */
	public TransportImpl(MessageQueue queue, MessageStreamConsumer consumer, RawChannel channel) {
		this.queue = queue;
		this.messageConsumer = consumer;
		this.channel = channel;
		this.decoder = new MessageDecoder(this.messageConsumer);
		this.data = channel.initialByteBuffer();
		// To ensure the initial non-final `data` will be visible
		// (happens-before) when `readEvent.handle()` invokes `receiveScheduler`
		// the following assignment is done last:
		receiveScheduler = new SequentialScheduler(new ReceiveTask());
	}

	/*
	 * Allocates the heap write buffer. Its capacity comes from the integer net
	 * property "jdk.httpclient.websocket.writeBufferSize" (default 16384).
	 */
	private ByteBuffer createWriteBuffer() {
		String name = "jdk.httpclient.websocket.writeBufferSize";
		int capacity = Utils.getIntegerNetProperty(name, 16384);
		if (debug.on()) {
			debug.log("write buffer capacity %s", capacity);
		}

		// TODO (optimization?): allocateDirect if SSL?
		return ByteBuffer.allocate(capacity);
	}

	/*
	 * Performs one gathering write of dstArray to the channel.
	 *
	 * Returns true if every buffer was fully drained, false if some bytes
	 * remain to be written later.
	 */
	private boolean write() throws IOException {
		if (debug.on()) {
			debug.log("writing to the channel");
		}
		long count = channel.write(dstArray, 0, dstArray.length);
		if (debug.on()) {
			debug.log("%s bytes written", count);
		}
		for (ByteBuffer b : dstArray) {
			if (b.hasRemaining()) {
				return false;
			}
		}
		return true;
	}

	/*
	 * Enqueues a text message and triggers the send task. If enqueuing fails
	 * with an IOException, the action is called with that error and the
	 * returned future completes exceptionally.
	 */
	@Override
	public <T> CompletableFuture<T> sendText(CharSequence message, boolean isLast, T attachment, BiConsumer<? super T, ? super Throwable> action) {
		long id = 0;
		if (debug.on()) {
			id = counter.incrementAndGet();
			debug.log("enter send text %s message.length=%s last=%s", id, message.length(), isLast);
		}
		// TODO (optimization?):
		// These sendXXX methods might be a good place to decide whether or not
		// we can write straight ahead, possibly returning null instead of
		// creating a CompletableFuture

		// Even if the text is already CharBuffer, the client will not be happy
		// if they discover the position is changing. So, no instanceof
		// cheating, wrap always.
		CharBuffer text = CharBuffer.wrap(message);
		MinimalFuture<T> f = new MinimalFuture<>();
		try {
			queue.addText(text, isLast, attachment, action, f);
			sendScheduler.runOrSchedule();
		} catch (IOException e) {
			action.accept(null, e);
			f.completeExceptionally(e);
		}
		if (debug.on()) {
			debug.log("exit send text %s returned %s", id, f);
		}
		return f;
	}

	/* Binary counterpart of sendText; same enqueue/failure handling. */
	@Override
	public <T> CompletableFuture<T> sendBinary(ByteBuffer message, boolean isLast, T attachment, BiConsumer<? super T, ? super Throwable> action) {
		long id = 0;
		if (debug.on()) {
			id = counter.incrementAndGet();
			debug.log("enter send binary %s message.remaining=%s last=%s", id, message.remaining(), isLast);
		}
		MinimalFuture<T> f = new MinimalFuture<>();
		try {
			queue.addBinary(message, isLast, attachment, action, f);
			sendScheduler.runOrSchedule();
		} catch (IOException e) {
			action.accept(null, e);
			f.completeExceptionally(e);
		}
		if (debug.on()) {
			debug.log("exit send binary %s returned %s", id, f);
		}
		return f;
	}

	/* Enqueues a Ping; same enqueue/failure handling as sendText. */
	@Override
	public <T> CompletableFuture<T> sendPing(ByteBuffer message, T attachment, BiConsumer<? super T, ? super Throwable> action) {
		long id = 0;
		if (debug.on()) {
			id = counter.incrementAndGet();
			debug.log("enter send ping %s message.remaining=%s", id, message.remaining());
		}
		MinimalFuture<T> f = new MinimalFuture<>();
		try {
			queue.addPing(message, attachment, action, f);
			sendScheduler.runOrSchedule();
		} catch (IOException e) {
			action.accept(null, e);
			f.completeExceptionally(e);
		}
		if (debug.on()) {
			debug.log("exit send ping %s returned %s", id, f);
		}
		return f;
	}

	/* Enqueues a Pong with a ready payload; same handling as sendText. */
	@Override
	public <T> CompletableFuture<T> sendPong(ByteBuffer message, T attachment, BiConsumer<? super T, ? super Throwable> action) {
		long id = 0;
		if (debug.on()) {
			id = counter.incrementAndGet();
			debug.log("enter send pong %s message.remaining=%s", id, message.remaining());
		}
		MinimalFuture<T> f = new MinimalFuture<>();
		try {
			queue.addPong(message, attachment, action, f);
			sendScheduler.runOrSchedule();
		} catch (IOException e) {
			action.accept(null, e);
			f.completeExceptionally(e);
		}
		if (debug.on()) {
			debug.log("exit send pong %s returned %s", id, f);
		}
		return f;
	}

	/*
	 * Enqueues a Pong whose payload is obtained from the supplier only when
	 * the send task encodes it; same handling as sendText.
	 */
	@Override
	public <T> CompletableFuture<T> sendPong(Supplier<? extends ByteBuffer> message, T attachment, BiConsumer<? super T, ? super Throwable> action) {
		long id = 0;
		if (debug.on()) {
			id = counter.incrementAndGet();
			debug.log("enter send pong %s supplier=%s", id, message);
		}
		MinimalFuture<T> f = new MinimalFuture<>();
		try {
			queue.addPong(message, attachment, action, f);
			sendScheduler.runOrSchedule();
		} catch (IOException e) {
			action.accept(null, e);
			f.completeExceptionally(e);
		}
		if (debug.on()) {
			debug.log("exit send pong %s returned %s", id, f);
		}
		return f;
	}

	/* Enqueues a Close message; same enqueue/failure handling as sendText. */
	@Override
	public <T> CompletableFuture<T> sendClose(int statusCode, String reason, T attachment, BiConsumer<? super T, ? super Throwable> action) {
		long id = 0;
		if (debug.on()) {
			id = counter.incrementAndGet();
			debug.log("enter send close %s statusCode=%s reason.length=%s", id, statusCode, reason.length());
		}
		MinimalFuture<T> f = new MinimalFuture<>();
		try {
			queue.addClose(statusCode, CharBuffer.wrap(reason), attachment, action, f);
			sendScheduler.runOrSchedule();
		} catch (IOException e) {
			action.accept(null, e);
			f.completeExceptionally(e);
		}
		if (debug.on()) {
			debug.log("exit send close %s returned %s", id, f);
		}
		return f;
	}

	/*
	 * Raises the receive demand by n. When Demand.increase reports true
	 * (presumably: demand was previously exhausted — confirm in Demand), the
	 * receive task is (re)started.
	 */
	@Override
	public void request(long n) {
		if (debug.on()) {
			debug.log("request %s", n);
		}
		if (demand.increase(n)) {
			receiveScheduler.runOrSchedule();
		}
	}

	/*
	 * Consumes one unit of demand for a delivered message. Failing to
	 * decrement indicates an internal bookkeeping error.
	 */
	@Override
	public void acknowledgeReception() {
		boolean decremented = demand.tryDecrement();
		if (!decremented) {
			throw new InternalError();
		}
	}

	/*
	 * Shuts down the channel's output (once); closes the channel if the input
	 * is already closed. Afterwards the write state is CLOSED and the send
	 * task is run, so queued and future messages fail with "Output closed".
	 */
	@Override
	public void closeOutput() throws IOException {
		if (debug.on()) {
			debug.log("closeOutput");
		}
		try {
			synchronized (closeLock) {
				if (!outputClosed) {
					outputClosed = true;
					try {
						channel.shutdownOutput();
					} finally {
						if (inputClosed) {
							channel.close();
						}
					}
				}
			}
		} finally {
			// Must happen even if shutdownOutput()/close() threw: outputClosed
			// is already true, so a retried closeOutput() would never get
			// here otherwise and sends would keep writing to the channel
			writeState.set(CLOSED);
			sendScheduler.runOrSchedule();
		}
	}

	/*
	 * Permanently stops reading from the channel and delivering messages regardless
	 * of the current demand and data availability.
	 * The channel is closed as well if the output has already been closed.
	 */
	@Override
	public void closeInput() throws IOException {
		if (debug.on()) {
			debug.log("closeInput");
		}
		synchronized (closeLock) {
			if (!inputClosed) {
				inputClosed = true;
				try {
					receiveScheduler.stop();
					channel.shutdownInput();
				} finally {
					if (outputClosed) {
						channel.close();
					}
				}
			}
		}
	}

	/* Common states for send and receive tasks */
	enum ChannelState {
		/* No interest event registered with the channel */
		UNREGISTERED,
		/* The event fired: the channel may be read from / written to */
		AVAILABLE,
		/* An event is registered; waiting for its handle() to be called */
		WAITING,
		/* Output closed; in this class only ever set on writeState */
		CLOSED,
	}

	/*
	 * Drains the outgoing queue. For the head message it loads the completion
	 * fields once, then repeatedly encodes (a part of) the message into dst
	 * and writes it, until the message is fully encoded and written; then it
	 * removes and completes the message. Any failure discards the buffered
	 * bytes and fails the head message.
	 */
	@SuppressWarnings(
		{ "rawtypes" })
	private class SendTask extends CompleteRestartableTask {

		/* Encodes (a part of) the head message into dst; true once fully encoded */
		private final MessageQueue.QueueCallback<Boolean, IOException> encodingCallback = new MessageQueue.QueueCallback<Boolean, IOException>() {

			@Override
			public <T> Boolean onText(CharBuffer message, boolean isLast, T attachment, BiConsumer<? super T, ? super Throwable> action,
					CompletableFuture<? super T> future) throws IOException {
				return encoder.encodeText(message, isLast, dst);
			}

			@Override
			public <T> Boolean onBinary(ByteBuffer message, boolean isLast, T attachment, BiConsumer<? super T, ? super Throwable> action,
					CompletableFuture<? super T> future) throws IOException {
				return encoder.encodeBinary(message, isLast, dst);
			}

			@Override
			public <T> Boolean onPing(ByteBuffer message, T attachment, BiConsumer<? super T, ? super Throwable> action, CompletableFuture<? super T> future)
					throws IOException {
				return encoder.encodePing(message, dst);
			}

			@Override
			public <T> Boolean onPong(ByteBuffer message, T attachment, BiConsumer<? super T, ? super Throwable> action, CompletableFuture<? super T> future)
					throws IOException {
				return encoder.encodePong(message, dst);
			}

			@Override
			public <T> Boolean onPong(Supplier<? extends ByteBuffer> message, T attachment, BiConsumer<? super T, ? super Throwable> action,
					CompletableFuture<? super T> future) throws IOException {
				// The payload is materialized only now, at encoding time
				return encoder.encodePong(message.get(), dst);
			}

			@Override
			public <T> Boolean onClose(int statusCode, CharBuffer reason, T attachment, BiConsumer<? super T, ? super Throwable> action,
					CompletableFuture<? super T> future) throws IOException {
				return encoder.encodeClose(statusCode, reason, dst);
			}

			@Override
			public Boolean onEmpty() {
				return false;
			}
		};

		/* Whether the task sees the current head message for first time */
		private boolean firstPass = true;
		/* Whether the message has been fully encoded */
		private boolean encoded;

		// -- Current message completion communication fields --

		private Object attachment;
		private BiConsumer action;
		private CompletableFuture future;
		private final MessageQueue.QueueCallback<Boolean, RuntimeException>
		/* If there is a message, loads its completion communication fields */
		loadCallback = new MessageQueue.QueueCallback<Boolean, RuntimeException>() {

			@Override
			public <T> Boolean onText(CharBuffer message, boolean isLast, T attachment, BiConsumer<? super T, ? super Throwable> action,
					CompletableFuture<? super T> future) {
				SendTask.this.attachment = attachment;
				SendTask.this.action = action;
				SendTask.this.future = future;
				return true;
			}

			@Override
			public <T> Boolean onBinary(ByteBuffer message, boolean isLast, T attachment, BiConsumer<? super T, ? super Throwable> action,
					CompletableFuture<? super T> future) {
				SendTask.this.attachment = attachment;
				SendTask.this.action = action;
				SendTask.this.future = future;
				return true;
			}

			@Override
			public <T> Boolean onPing(ByteBuffer message, T attachment, BiConsumer<? super T, ? super Throwable> action, CompletableFuture<? super T> future) {
				SendTask.this.attachment = attachment;
				SendTask.this.action = action;
				SendTask.this.future = future;
				return true;
			}

			@Override
			public <T> Boolean onPong(ByteBuffer message, T attachment, BiConsumer<? super T, ? super Throwable> action, CompletableFuture<? super T> future) {
				SendTask.this.attachment = attachment;
				SendTask.this.action = action;
				SendTask.this.future = future;
				return true;
			}

			@Override
			public <T> Boolean onPong(Supplier<? extends ByteBuffer> message, T attachment, BiConsumer<? super T, ? super Throwable> action,
					CompletableFuture<? super T> future) {
				SendTask.this.attachment = attachment;
				SendTask.this.action = action;
				SendTask.this.future = future;
				return true;
			}

			@Override
			public <T> Boolean onClose(int statusCode, CharBuffer reason, T attachment, BiConsumer<? super T, ? super Throwable> action,
					CompletableFuture<? super T> future) {
				SendTask.this.attachment = attachment;
				SendTask.this.action = action;
				SendTask.this.future = future;
				return true;
			}

			@Override
			public Boolean onEmpty() {
				return false;
			}
		};

		@Override
		public void run() {
			// Could have been only called in one of the following cases:
			// (a) A message has been added to the queue
			// (b) The channel is ready for writing
			if (debug.on()) {
				debug.log("enter send task");
			}
			while (!queue.isEmpty()) {
				try {
					if (dst.hasRemaining()) {
						if (debug.on()) {
							debug.log("%s bytes remaining in buffer %s", dst.remaining(), dst);
						}
						// The previous part of the binary representation of the
						// message hasn't been fully written
						if (!tryCompleteWrite()) {
							break;
						}
					} else if (!encoded) {
						if (firstPass) {
							firstPass = false;
							queue.peek(loadCallback);
							if (debug.on()) {
								debug.log("load message");
							}
						}
						// Encode the next part of the message into the
						// (emptied) buffer, then flip it for writing
						dst.clear();
						encoded = queue.peek(encodingCallback);
						dst.flip();
						if (!tryCompleteWrite()) {
							break;
						}
					} else {
						// All done, remove and complete
						encoder.reset();
						removeAndComplete(null);
					}
				} catch (Throwable t) {
					if (debug.on()) {
						debug.log("send task exception %s", (Object) t);
					}
					// buffer cleanup: if there is an exception, the buffer
					// should appear empty for the next write as there is
					// nothing to write
					dst.position(dst.limit());
					encoder.reset();
					removeAndComplete(t);
				}
			}
			if (debug.on()) {
				debug.log("exit send task");
			}
		}

		/*
		 * Tries to write the rest of dst, driven by writeState:
		 * WAITING      -> nothing to do until WriteEvent fires, returns false;
		 * UNREGISTERED -> registers the write event, returns false;
		 * AVAILABLE    -> writes; returns true if dst was drained, otherwise
		 *                 falls back to UNREGISTERED and registers;
		 * CLOSED       -> throws IOException("Output closed").
		 */
		private boolean tryCompleteWrite() throws IOException {
			if (debug.on()) {
				debug.log("enter writing");
			}
			boolean finished = false;
			loop: while (true) {
				final ChannelState ws = writeState.get();
				if (debug.on()) {
					debug.log("write state: %s", ws);
				}
				switch (ws) {
					case WAITING:
						break loop;
					case UNREGISTERED:
						if (debug.on()) {
							debug.log("registering write event");
						}
						channel.registerEvent(writeEvent);
						// CAS, not set: the event may already have fired and
						// moved the state to AVAILABLE (or it may be CLOSED)
						writeState.compareAndSet(UNREGISTERED, WAITING);
						if (debug.on()) {
							debug.log("registered write event");
						}
						break loop;
					case AVAILABLE:
						boolean written = write();
						if (written) {
							if (debug.on()) {
								debug.log("finished writing to the channel");
							}
							finished = true;
							break loop; // All done
						} else {
							writeState.compareAndSet(AVAILABLE, UNREGISTERED);
							continue loop; // Effectively "goto UNREGISTERED"
						}
					case CLOSED:
						throw new IOException("Output closed");
					default:
						throw new InternalError(String.valueOf(ws));
				}
			}
			if (debug.on()) {
				debug.log("exit writing");
			}
			return finished;
		}

		/*
		 * Removes the head message from the queue and completes it: calls its
		 * action, then completes its future — normally with the attachment if
		 * error is null, exceptionally with error otherwise.
		 *
		 * The per-message fields are cleared before any callback runs, and an
		 * exception thrown by the action is contained here (the future is
		 * still completed). If it were allowed to propagate, run()'s catch
		 * block would call this method again with the fields still describing
		 * the already-removed message: the next queued message would be
		 * removed without ever being completed and the stale action would be
		 * invoked a second time.
		 */
		@SuppressWarnings("unchecked")
		private void removeAndComplete(Throwable error) {
			if (debug.on()) {
				debug.log("removeAndComplete error=%s", (Object) error);
			}
			final Object completedAttachment = attachment;
			final BiConsumer completedAction = action;
			final CompletableFuture completedFuture = future;
			// Reset first, so the task is ready for the next head message
			// whatever the callbacks below do
			encoded = false;
			firstPass = true;
			attachment = null;
			action = null;
			future = null;
			queue.remove();
			try {
				if (error != null) {
					completedAction.accept(null, error);
				} else {
					completedAction.accept(completedAttachment, null);
				}
			} catch (Throwable t) {
				if (debug.on()) {
					debug.log("completion action failed %s", (Object) t);
				}
			} finally {
				if (error != null) {
					completedFuture.completeExceptionally(error);
				} else {
					completedFuture.complete(completedAttachment);
				}
			}
		}
	}

	/*
	 * Reads and decodes incoming frames while there is demand. Buffered bytes
	 * are decoded first; when none are left, readState decides whether to
	 * wait, register a read event, or read more from the channel. A null read
	 * signals EOF (onComplete); any failure stops the task and is reported
	 * through onError.
	 */
	private class ReceiveTask extends CompleteRestartableTask {

		@Override
		public void run() {
			if (debug.on()) {
				debug.log("enter receive task");
			}
			loop: while (!receiveScheduler.isStopped()) {
				ChannelState rs = readState;
				if (data.hasRemaining()) {
					if (debug.on()) {
						debug.log("remaining bytes received %s", data.remaining());
					}
					if (!demand.isFulfilled()) {
						try {
							int oldPos = data.position();
							reader.readFrame(data, decoder);
							int newPos = data.position();
							// Reader always consumes bytes:
							assert oldPos != newPos : data;
						} catch (Throwable e) {
							receiveScheduler.stop();
							messageConsumer.onError(e);
						}
						if (!data.hasRemaining()) {
							rs = readState = UNREGISTERED;
						}
						continue;
					}
					// No demand: keep the bytes until request() restarts us
					break loop;
				}
				if (debug.on()) {
					debug.log("receive state: %s", rs);
				}
				switch (rs) {
					case WAITING:
						break loop;
					case UNREGISTERED:
						try {
							rs = readState = WAITING;
							channel.registerEvent(readEvent);
						} catch (Throwable e) {
							receiveScheduler.stop();
							messageConsumer.onError(e);
						}
						break loop;
					case AVAILABLE:
						try {
							data = channel.read();
						} catch (Throwable e) {
							receiveScheduler.stop();
							messageConsumer.onError(e);
							break loop;
						}
						if (data == null) { // EOF
							receiveScheduler.stop();
							messageConsumer.onComplete();
							break loop;
						} else if (!data.hasRemaining()) {
							// No data at the moment. Pretty much a "goto",
							// reusing the existing code path for registration
							rs = readState = UNREGISTERED;
						}
						continue loop;
					default:
						throw new InternalError(String.valueOf(rs));
				}
			}
			if (debug.on()) {
				debug.log("exit receive task");
			}
		}
	}

	/* OP_WRITE interest: marks the channel writable and reruns the send task. */
	private class WriteEvent implements RawChannel.RawEvent {

		@Override
		public int interestOps() {
			return SelectionKey.OP_WRITE;
		}

		/*
		 * Moves writeState to AVAILABLE unless it is CLOSED (CLOSED must stick),
		 * then runs the send task either way.
		 */
		@Override
		public void handle() {
			if (debug.on()) {
				debug.log("write event");
			}
			ChannelState s;
			do {
				s = writeState.get();
				if (s == CLOSED) {
					if (debug.on()) {
						debug.log("write state %s", s);
					}
					break;
				}
			} while (!writeState.compareAndSet(s, AVAILABLE));
			sendScheduler.runOrSchedule();
		}
	}

	/* OP_READ interest: marks the channel readable and reruns the receive task. */
	private class ReadEvent implements RawChannel.RawEvent {

		@Override
		public int interestOps() {
			return SelectionKey.OP_READ;
		}

		@Override
		public void handle() {
			if (debug.on()) {
				debug.log("read event");
			}
			readState = AVAILABLE;
			receiveScheduler.runOrSchedule();
		}
	}
}